(*
 * Copyright 2020, Data61, CSIRO (ABN 41 687 119 230)
 *
 * SPDX-License-Identifier: BSD-2-Clause
 *)

(*
 * Abstract data structures for representing and performing basic analysis on
 * programs.
 *)
structure Prog =
struct

(*
 * Shorthand for the Varset operations used throughout this structure.
 *
 * The three infix operators share precedence level 1 and associate to the
 * left, so mixed chains need explicit parentheses.
 *)
infix 1 INTER MINUS UNION
val empty_set = Varset.empty
fun make_set elems = Varset.make elems
fun union_sets sets = Varset.union_sets sets
fun lhs INTER rhs = Varset.inter lhs rhs
(* Arguments are deliberately swapped: presumably Varset.subtract removes its
 * first argument from its second -- confirm against Varset. *)
fun lhs MINUS rhs = Varset.subtract rhs lhs
fun lhs UNION rhs = Varset.union lhs rhs

(*
 * Abstract syntax tree of a parsed program.
 *
 * Type parameters:
 *
 *   "'a"  payload stored at every node (always the first component).
 *
 *   "'e"  payload stored once per expression.
 *
 *   "'m"  payload stored in modification clauses only.
 *
 *   "'c"  meta-data attached to calls.
 *)
datatype ('a, 'e, 'm, 'c) prog =
    Init of 'a * 'm
  | Modify of 'a * 'e * 'm
  | Guard of 'a * 'e
  | Throw of 'a
    (* A list of expressions, one further expression, a modification and the
     * call meta-data. *)
  | Call of 'a * 'e list * 'e * 'm * 'c
  | Spec of 'a * 'e
  | Fail of 'a
    (* One expression and one sub-program. *)
  | While of 'a * 'e * ('a, 'e, 'm, 'c) prog
    (* One expression and two sub-programs. *)
  | Condition of 'a * 'e * ('a, 'e, 'm, 'c) prog * ('a, 'e, 'm, 'c) prog
  | Seq of 'a * ('a, 'e, 'm, 'c) prog * ('a, 'e, 'm, 'c) prog
  | Catch of 'a * ('a, 'e, 'm, 'c) prog * ('a, 'e, 'm, 'c) prog
  | RecGuard of 'a * ('a, 'e, 'm, 'c) prog

(*
 * Classification of a call; not used elsewhere in this structure.
 *
 * NOTE(review): presumably distinguishes calls that decrement an existing
 * termination measure from calls that start with a new one -- confirm
 * against the users of this type.
 *)
datatype call_type = DecMeasure | NewMeasure

(* Return the "'a" payload stored at the root node of the given program. *)
fun get_node_data (Init (data, _)) = data
  | get_node_data (Modify (data, _, _)) = data
  | get_node_data (Guard (data, _)) = data
  | get_node_data (Throw data) = data
  | get_node_data (Call (data, _, _, _, _)) = data
  | get_node_data (Spec (data, _)) = data
  | get_node_data (Fail data) = data
  | get_node_data (While (data, _, _)) = data
  | get_node_data (Condition (data, _, _, _)) = data
  | get_node_data (Seq (data, _, _)) = data
  | get_node_data (Catch (data, _, _)) = data
  | get_node_data (RecGuard (data, _)) = data

(*
 * Pair up the payloads of two programs that have the same shape.
 *
 * Every payload of the result is the pair of the corresponding payloads of
 * the two inputs; the expression lists of a Call are combined with
 * Utils.zip. If the two programs differ in their constructors at any
 * position, Utils.invalid_input is invoked with a rendering of the
 * offending pair.
 *)
fun zip_progs (Init (n1, w1)) (Init (n2, w2)) =
      Init ((n1, n2), (w1, w2))
  | zip_progs (Modify (n1, x1, w1)) (Modify (n2, x2, w2)) =
      Modify ((n1, n2), (x1, x2), (w1, w2))
  | zip_progs (Guard (n1, x1)) (Guard (n2, x2)) =
      Guard ((n1, n2), (x1, x2))
  | zip_progs (Throw n1) (Throw n2) =
      Throw (n1, n2)
  | zip_progs (Call (n1, xs1, r1, w1, d1)) (Call (n2, xs2, r2, w2, d2)) =
      Call ((n1, n2), Utils.zip xs1 xs2, (r1, r2), (w1, w2), (d1, d2))
  | zip_progs (Spec (n1, x1)) (Spec (n2, x2)) =
      Spec ((n1, n2), (x1, x2))
  | zip_progs (Fail n1) (Fail n2) =
      Fail (n1, n2)
  | zip_progs (While (n1, x1, b1)) (While (n2, x2, b2)) =
      While ((n1, n2), (x1, x2), zip_progs b1 b2)
  | zip_progs (Condition (n1, x1, l1, r1)) (Condition (n2, x2, l2, r2)) =
      Condition ((n1, n2), (x1, x2), zip_progs l1 l2, zip_progs r1 r2)
  | zip_progs (Seq (n1, l1, r1)) (Seq (n2, l2, r2)) =
      Seq ((n1, n2), zip_progs l1 l2, zip_progs r1 r2)
  | zip_progs (Catch (n1, l1, r1)) (Catch (n2, l2, r2)) =
      Catch ((n1, n2), zip_progs l1 l2, zip_progs r1 r2)
  | zip_progs (RecGuard (n1, b1)) (RecGuard (n2, b2)) =
      RecGuard ((n1, n2), zip_progs b1 b2)
  | zip_progs left right =
      (* Constructors differ: report the mismatching pair. *)
      Utils.invalid_input "structurally identical programs" (@{make_string} (left, right))

(*
 * Rebuild a program with every payload transformed.
 *
 *   node_fn  is applied to the node payload of every node,
 *   expr_fn  to every expression payload (including each element of a
 *            Call's expression list),
 *   mod_fn   to every modification payload, and
 *   call_fn  to the meta-data of every Call.
 *
 * The shape of the program is unchanged. Within a node the payloads are
 * transformed left to right, before the sub-programs.
 *)
fun map_prog node_fn expr_fn mod_fn call_fn prog =
  let
    (* The four mapping functions are fixed for the whole traversal. *)
    fun go (Init (n, w)) = Init (node_fn n, mod_fn w)
      | go (Modify (n, x, w)) = Modify (node_fn n, expr_fn x, mod_fn w)
      | go (Guard (n, x)) = Guard (node_fn n, expr_fn x)
      | go (Throw n) = Throw (node_fn n)
      | go (Call (n, xs, r, w, d)) =
          Call (node_fn n, map expr_fn xs, expr_fn r, mod_fn w, call_fn d)
      | go (Spec (n, x)) = Spec (node_fn n, expr_fn x)
      | go (Fail n) = Fail (node_fn n)
      | go (While (n, x, body)) = While (node_fn n, expr_fn x, go body)
      | go (Condition (n, x, lhs, rhs)) =
          Condition (node_fn n, expr_fn x, go lhs, go rhs)
      | go (Seq (n, lhs, rhs)) = Seq (node_fn n, go lhs, go rhs)
      | go (Catch (n, lhs, rhs)) = Catch (node_fn n, go lhs, go rhs)
      | go (RecGuard (n, body)) = RecGuard (node_fn n, go body)
  in
    go prog
  end
(*
 * Thread an accumulator through every payload of the program, visiting a
 * node's own payloads before those of its sub-programs (pre-order), and
 * sub-programs from left to right.
 *
 * Each of node_fn, expr_fn, mod_fn and call_fn takes a payload and returns
 * an accumulator transformer; "v" is the initial accumulator.
 *)
fun fold_prog node_fn expr_fn mod_fn call_fn prog v =
  let
    (* "go p" is only ever partially applied inside the compositions below,
     * so a sub-program is not inspected until the accumulator reaches it. *)
    fun go (Init (n, w)) acc = (node_fn n #> mod_fn w) acc
      | go (Modify (n, x, w)) acc = (node_fn n #> expr_fn x #> mod_fn w) acc
      | go (Guard (n, x)) acc = (node_fn n #> expr_fn x) acc
      | go (Throw n) acc = (node_fn n) acc
      | go (Call (n, xs, r, w, d)) acc =
          (node_fn n #> fold expr_fn xs #> expr_fn r #> mod_fn w #> call_fn d) acc
      | go (Spec (n, x)) acc = (node_fn n #> expr_fn x) acc
      | go (Fail n) acc = (node_fn n) acc
      | go (While (n, x, body)) acc =
          (node_fn n #> expr_fn x #> go body) acc
      | go (Condition (n, x, lhs, rhs)) acc =
          (node_fn n #> expr_fn x #> go lhs #> go rhs) acc
      | go (Seq (n, lhs, rhs)) acc =
          (node_fn n #> go lhs #> go rhs) acc
      | go (Catch (n, lhs, rhs)) acc =
          (node_fn n #> go lhs #> go rhs) acc
      | go (RecGuard (n, body)) acc =
          (node_fn n #> go body) acc
  in
    go prog v
  end
(*
 * Perform a liveness analysis on the given program.
 *
 * Each node's data will contain the set of live variables _prior_ to the block
 * being executed.
 *
 * For instance, with nothing but the final assignment's target live on exit:
 *
 *    Condition (a < 3)          -- [a, b, c] live
 *       Modify (x := b)         -- [b] live
 *       Modify (x := c)         -- [c] live
 *    Modify (ret := x)          -- [x] live
 *
 * "calc_live_vars prog output_vars" expects every expression payload of
 * "prog" to be a triple whose middle component is the set of variables read
 * (a list of such triples for the first expression field of a Call), and
 * every modification payload to be an optional written variable.
 * "output_vars" is the set of variables live after the whole program.
 *
 * The result is computed by iterating a single backwards pass
 * (calc_live_vars') with Utils.fixpoint until two successive annotated
 * programs compare equal. Each pass only ever adds to a node's previous set
 * ("old"), except at Throw and Fail, whose sets are recomputed outright.
 *)
local
(*
 * One backwards pass.
 *
 *   term        program annotated with the previous iteration's live sets
 *   succ_live   variables live when "term" finishes normally
 *   throw_live  variables live at the handler an enclosing Catch provides
 *)
fun calc_live_vars' term succ_live throw_live =
let
  (* Zero-or-one-element set from an optional written variable. *)
  fun set_from_some x =
    case x of
        NONE => empty_set
      | SOME x => make_set [x]
in
  case term of
      Init (old, written_vars) =>
        Init (old UNION (succ_live MINUS (set_from_some written_vars)), written_vars)
    | Modify (old, read_vars, written_vars) =>
        Modify (old UNION read_vars UNION (succ_live MINUS (set_from_some written_vars)), read_vars, written_vars)
    | Call (old, read_vars, ret_read_vars, written_vars, call_data) =>
        Call (old UNION (union_sets read_vars)
            UNION ret_read_vars
            UNION (succ_live MINUS (set_from_some written_vars)),
            read_vars, ret_read_vars, written_vars, call_data)
    | Guard (old, read_vars) =>
        Guard (old UNION succ_live UNION read_vars, read_vars)
    | Throw _ =>
        (* Control continues at the handler, not at the successor. *)
        Throw throw_live
    | Spec (old, read_vars) =>
        (* The successor's live set is not propagated through a Spec. *)
        Spec (old UNION read_vars, read_vars)
    | Fail _ =>
        (* NOTE(review): Fail is given the successor's live set; if Fail can
         * never fall through, empty_set would be a tighter answer -- confirm. *)
        Fail succ_live
    | While (old, read_vars, body) =>
        let
          (* After one iteration control is either back at the loop head
           * ("old", from the previous pass) or at the successor. *)
          val new_body = calc_live_vars' body (succ_live UNION old) throw_live
          val body_live = get_node_data new_body
        in
          (* The loop may run zero times, so whatever is live after it must
           * also be live before it. Without succ_live here, a variable that
           * the body overwrites unconditionally but the successor reads
           * would be missing from the loop's live set. *)
          While (old UNION body_live UNION succ_live UNION read_vars, read_vars, new_body)
        end
    | Condition (old, read_vars, lhs, rhs) =>
        let
          val new_lhs = calc_live_vars' lhs succ_live throw_live
          val lhs_live = get_node_data new_lhs
          val new_rhs = calc_live_vars' rhs succ_live throw_live
          val rhs_live = get_node_data new_rhs
        in
          Condition (old UNION lhs_live UNION rhs_live UNION read_vars, read_vars, new_lhs, new_rhs)
        end
    | Seq (old, lhs, rhs) =>
        let
          (* Backwards analysis: the right half is processed first and its
           * live set becomes the successor set of the left half. *)
          val new_rhs = calc_live_vars' rhs succ_live throw_live
          val rhs_live = get_node_data new_rhs
          val new_lhs = calc_live_vars' lhs rhs_live throw_live
          val lhs_live = get_node_data new_lhs
        in
          Seq (old UNION lhs_live, new_lhs, new_rhs)
        end
    | Catch (old, lhs, rhs) =>
        let
          (* The handler (rhs) supplies the throw-live set for the body. *)
          val new_rhs = calc_live_vars' rhs succ_live throw_live
          val rhs_live = get_node_data new_rhs
          val new_lhs = calc_live_vars' lhs succ_live rhs_live
          val lhs_live = get_node_data new_lhs
        in
          Catch (old UNION lhs_live, new_lhs, new_rhs)
        end
    | RecGuard (old, body) =>
        let
          val new_body = calc_live_vars' body succ_live throw_live
          val body_live = get_node_data new_body
        in
          RecGuard (old UNION body_live, new_body)
        end
end
in
fun calc_live_vars prog output_vars =
  let
    (* Start with every node's live set empty; keep only the middle component
     * of each expression payload and leave the other payloads untouched. *)
    val init = map_prog (fn _ => empty_set) (fn (_, a, _) => a) (fn x => x) (fn x => x) prog
  in
    (* Nothing is live at an uncaught throw at the top level. *)
    Utils.fixpoint (fn x => calc_live_vars' x output_vars empty_set) (op =) init
  end
end


(*
 * Get the variables read by each block of code.
 *
 * Each node's data will contain the set of variables read in the
 * given block. The input is expected to carry a set of read variables as
 * each expression payload (a list of sets for the first expression field of
 * a Call); the previous node payloads are ignored and replaced.
 *
 * For compound nodes the result is the union of the sets of all
 * sub-programs, plus the node's own expression where it has one.
 *)
fun get_read_vars term =
  case term of
      Init (_, written_vars) =>
        Init (empty_set, written_vars)
    | Modify (_, read_vars, written_vars) =>
        Modify (read_vars, read_vars, written_vars)
    | Call (_, read_vars, ret_read_vars, written_vars, call_data) =>
        (* Include the variables read by the return expression as well, in
         * agreement with how calc_live_vars treats a Call. *)
        Call (union_sets read_vars UNION ret_read_vars,
            read_vars, ret_read_vars, written_vars, call_data)
    | Guard (_, read_vars) =>
        Guard (read_vars, read_vars)
    | Throw _ =>
        Throw empty_set
    | Spec (_, read_vars) =>
        Spec (read_vars, read_vars)
    | Fail _ =>
        Fail empty_set
    | While (_, read_vars, body) =>
        let
          val new_body = get_read_vars body
          val new_reads = get_node_data new_body
        in
          While (new_reads UNION read_vars, read_vars, new_body)
        end
    | Condition (_, read_vars, lhs, rhs) =>
        let
          val new_lhs = get_read_vars lhs
          val new_rhs = get_read_vars rhs
          val new_reads = get_node_data new_lhs UNION get_node_data new_rhs
        in
          Condition (new_reads UNION read_vars, read_vars, new_lhs, new_rhs)
        end
    | Seq (_, lhs, rhs) =>
        let
          val new_lhs = get_read_vars lhs
          val new_rhs = get_read_vars rhs
          val new_reads = get_node_data new_lhs UNION get_node_data new_rhs
        in
          Seq (new_reads, new_lhs, new_rhs)
        end
    | Catch (_, lhs, rhs) =>
        let
          val new_lhs = get_read_vars lhs
          val new_rhs = get_read_vars rhs
          val new_reads = get_node_data new_lhs UNION get_node_data new_rhs
        in
          Catch (new_reads, new_lhs, new_rhs)
        end
    | RecGuard (_, body) =>
        let
          val new_body = get_read_vars body
          val new_reads = get_node_data new_body
        in
          RecGuard (new_reads, new_body)
        end

(*
 * Annotate each block of code with the variables it may modify.
 *
 * A node's data becomes "SOME s", where "s" is the set of variables written
 * somewhere in that block, or "NONE" when the block may modify anything
 * (NONE plays the role of the set UNIV). A Spec node is always NONE, and
 * NONE propagates upwards through Condition, Seq and Catch.
 *
 * Modification payloads are expected to be an optional written variable;
 * the previous node payloads are ignored and replaced.
 *)
fun get_modified_vars term =
let
  (* Union of two results, with NONE absorbing everything. *)
  fun join (SOME xs, SOME ys) = SOME (xs UNION ys)
    | join _ = NONE

  (* Zero-or-one-element set from an optional variable. *)
  fun set_of_opt NONE = empty_set
    | set_of_opt (SOME var) = make_set [var]

  (* Annotate a single sub-program and return its result with it. *)
  fun annotate_one body =
    let
      val body' = get_modified_vars body
    in
      (get_node_data body', body')
    end

  (* Annotate two sub-programs, left first, and join their results. *)
  fun annotate_pair lhs rhs =
    let
      val lhs' = get_modified_vars lhs
      val rhs' = get_modified_vars rhs
    in
      (join (get_node_data lhs', get_node_data rhs'), lhs', rhs')
    end
in
  case term of
      Init (_, written) =>
        Init (SOME (set_of_opt written), written)
    | Modify (_, reads, written) =>
        Modify (SOME (set_of_opt written), reads, written)
    | Call (_, reads, ret_reads, written, call_data) =>
        Call (SOME (set_of_opt written), reads, ret_reads, written, call_data)
    | Guard (_, reads) =>
        Guard (SOME empty_set, reads)
    | Throw _ =>
        Throw (SOME empty_set)
    | Spec (_, reads) =>
        (* A Spec is treated as modifying every variable. *)
        Spec (NONE, reads)
    | Fail _ =>
        Fail (SOME empty_set)
    | While (_, reads, body) =>
        let
          val (mods, body') = annotate_one body
        in
          While (mods, reads, body')
        end
    | Condition (_, reads, lhs, rhs) =>
        let
          val (mods, lhs', rhs') = annotate_pair lhs rhs
        in
          Condition (mods, reads, lhs', rhs')
        end
    | Seq (_, lhs, rhs) =>
        let
          val (mods, lhs', rhs') = annotate_pair lhs rhs
        in
          Seq (mods, lhs', rhs')
        end
    | Catch (_, lhs, rhs) =>
        let
          val (mods, lhs', rhs') = annotate_pair lhs rhs
        in
          Catch (mods, lhs', rhs')
        end
    | RecGuard (_, body) =>
        let
          val (mods, body') = annotate_one body
        in
          RecGuard (mods, body')
        end
end

end
